#include <gtest/gtest.h>
#include <mockcpp/mockcpp.hpp>
#include "graph/debug/ge_error_codes.h"
#include "framework/graph/debug/ge_util.h"
#include "framework/graph/core/node/node_spec.h"
#include "framework/graph/core/cgraph/graph_modifier.h"
#define protected public
#define private public
#include "framework/graph/core/cgraph/compute_graph.h"
#include "framework/graph/utils/graph_utils.h"
#include "framework/graph/utils/attr_utils.h"
#include "framework/graph/utils/op_desc_utils.h"
#include "framework/graph/core/node/node.h"
#include "framework/graph/core/infershape/op_ir_func_factory.h"
#include "framework/graph/core/infershape/op_ir_facade.h"
#include "graph/op/nn_defs.h"
#include "graph/op/internal_defs.h"
#include "framework/graph/core/infershape/graph_infershape_util.h"
#include <vector>
#undef protected
#undef private
using namespace ge;
using namespace testing;
using namespace std;
using namespace hiai;

class UTESTGeGraphInfershape : public testing::Test {
protected:
    void SetUp()
    {
    }

    void TearDown()
    {
        GlobalMockObject::verify();
    }
};

// Builds Data + Data -> PReLU and checks that shape inference gives the PReLU output
// the {1, 4, 3, 3} dims of its first input.
TEST_F(UTESTGeGraphInfershape, prelu_infershape_success)
{
    std::shared_ptr<ComputeGraph> graphPtr = ge::ComputeGraph::Make("ut_test_graph");
    OpDescPtr dataOpdescPtr = std::make_shared<OpDesc>("Data", "Data");
    Node* dataNodePtr = graphPtr->ROLE(GraphModifier).AddNode(dataOpdescPtr);

    OpDescPtr dataOpdescPtr2 = std::make_shared<OpDesc>("Data", "Data");
    Node* dataNodePtr2 = graphPtr->ROLE(GraphModifier).AddNode(dataOpdescPtr2);

    OpDescPtr preluOpDescPtr = std::make_shared<OpDesc>("PReLU", "PReLU");
    Node* preluNodePtr = graphPtr->ROLE(GraphModifier).AddNode(preluOpDescPtr);

    // First Data node: 4-D float tensor.
    TensorDesc tensorDesc1(Shape({1, 4, 3, 3}));
    tensorDesc1.SetDataType(DT_FLOAT);
    dataOpdescPtr->AddInputDesc(tensorDesc1);
    dataOpdescPtr->AddOutputDesc(tensorDesc1);

    // Second Data node: 1-D float tensor with 4 elements.
    TensorDesc tensorDesc2(Shape({4}));
    tensorDesc2.SetDataType(DT_FLOAT);
    dataOpdescPtr2->AddInputDesc(tensorDesc2);
    dataOpdescPtr2->AddOutputDesc(tensorDesc2);

    // PReLU starts with a default-constructed output descriptor that inference must fill in.
    TensorDesc outDesc;
    preluOpDescPtr->AddInputDesc(tensorDesc1);
    preluOpDescPtr->AddInputDesc(tensorDesc2);
    preluOpDescPtr->AddOutputDesc(outDesc);

    // Input 0 is fed by the first Data node ({1, 4, 3, 3}) and input 1 by the second one ({4}),
    // so each edge matches the descriptor registered on the corresponding PReLU input.
    (void)graphPtr->ROLE(GraphModifier).AddEdge({*dataNodePtr, 0}, {*preluNodePtr, 0});
    (void)graphPtr->ROLE(GraphModifier).AddEdge({*dataNodePtr2, 0}, {*preluNodePtr, 1});

    EXPECT_EQ(ge::GRAPH_SUCCESS, GraphInfershapeUtil::InferShape(*graphPtr));

    TensorDesc output = preluNodePtr->ROLE(NodeSpec).OpDesc().GetOutputDesc(0);

    Shape out = output.GetShape();
    EXPECT_EQ(1, out.GetDim(0));
    EXPECT_EQ(4, out.GetDim(1));
    EXPECT_EQ(3, out.GetDim(2));
    EXPECT_EQ(3, out.GetDim(3));
}

// Data -> Sin graph: after shape inference the Sin output is expected to report
// the dims {1, 4, 3, 3} of the float descriptor registered on its input.
TEST_F(UTESTGeGraphInfershape, graph_sin_infershape_success)
{
    std::shared_ptr<ComputeGraph> graph = ge::ComputeGraph::Make("ut_test_graph");
    auto&& modifier = graph->ROLE(GraphModifier);

    TensorDesc inDesc(Shape({1, 4, 3, 3}));
    inDesc.SetDataType(DT_FLOAT);
    TensorDesc emptyDesc;

    // The consumer is described and added to the graph before its producer.
    OpDescPtr sinOp = std::make_shared<OpDesc>("Sin", "Sin");
    sinOp->AddInputDesc(inDesc);
    sinOp->AddOutputDesc(emptyDesc);
    Node* sinNode = modifier.AddNode(sinOp);

    OpDescPtr dataOp = std::make_shared<OpDesc>("Data", "Data");
    dataOp->AddInputDesc(inDesc);
    dataOp->AddOutputDesc(inDesc);
    Node* dataNode = modifier.AddNode(dataOp);

    (void)modifier.AddEdge({*dataNode, 0}, {*sinNode, 0});

    const auto status = GraphInfershapeUtil::InferShape(*graph);
    EXPECT_EQ(ge::GRAPH_SUCCESS, status);

    TensorDesc inferred = sinNode->ROLE(NodeSpec).OpDesc().GetOutputDesc(0);
    Shape dims = inferred.GetShape();
    EXPECT_EQ(1, dims.GetDim(0));
    EXPECT_EQ(4, dims.GetDim(1));
    EXPECT_EQ(3, dims.GetDim(2));
    EXPECT_EQ(3, dims.GetDim(3));
}

// Data -> Cos graph: the inferred Cos output is expected to have dims {1, 4, 3, 3},
// identical to the float input descriptor.
TEST_F(UTESTGeGraphInfershape, graph_cos_infershape_success)
{
    std::shared_ptr<ComputeGraph> computeGraph = ge::ComputeGraph::Make("ut_test_graph");

    TensorDesc srcDesc(Shape({1, 4, 3, 3}));
    srcDesc.SetDataType(DT_FLOAT);
    TensorDesc blankDesc;

    // Cos op: one described input, one output left for inference to fill in.
    OpDescPtr cosDesc = std::make_shared<OpDesc>("Cos", "Cos");
    cosDesc->AddInputDesc(srcDesc);
    cosDesc->AddOutputDesc(blankDesc);
    Node* cosNode = computeGraph->ROLE(GraphModifier).AddNode(cosDesc);

    // Data op that produces the tensor consumed by Cos.
    OpDescPtr srcOpDesc = std::make_shared<OpDesc>("Data", "Data");
    srcOpDesc->AddInputDesc(srcDesc);
    srcOpDesc->AddOutputDesc(srcDesc);
    Node* srcNode = computeGraph->ROLE(GraphModifier).AddNode(srcOpDesc);

    (void)computeGraph->ROLE(GraphModifier).AddEdge({*srcNode, 0}, {*cosNode, 0});

    const auto inferStatus = GraphInfershapeUtil::InferShape(*computeGraph);
    EXPECT_EQ(ge::GRAPH_SUCCESS, inferStatus);

    TensorDesc cosOut = cosNode->ROLE(NodeSpec).OpDesc().GetOutputDesc(0);
    Shape cosShape = cosOut.GetShape();
    EXPECT_EQ(1, cosShape.GetDim(0));
    EXPECT_EQ(4, cosShape.GetDim(1));
    EXPECT_EQ(3, cosShape.GetDim(2));
    EXPECT_EQ(3, cosShape.GetDim(3));
}

// Data -> Tan graph: shape inference must succeed and leave dims {1, 4, 3, 3}
// on the Tan output descriptor.
TEST_F(UTESTGeGraphInfershape, graph_tan_infershape_success)
{
    std::shared_ptr<ComputeGraph> g = ge::ComputeGraph::Make("ut_test_graph");
    auto&& graphEditor = g->ROLE(GraphModifier);

    TensorDesc floatDesc(Shape({1, 4, 3, 3}));
    floatDesc.SetDataType(DT_FLOAT);
    TensorDesc pendingDesc;

    OpDescPtr tanOp = std::make_shared<OpDesc>("Tan", "Tan");
    tanOp->AddInputDesc(floatDesc);
    tanOp->AddOutputDesc(pendingDesc);
    Node* tanNode = graphEditor.AddNode(tanOp);

    OpDescPtr feedOp = std::make_shared<OpDesc>("Data", "Data");
    feedOp->AddInputDesc(floatDesc);
    feedOp->AddOutputDesc(floatDesc);
    Node* feedNode = graphEditor.AddNode(feedOp);

    // Wire Data output 0 into Tan input 0.
    (void)graphEditor.AddEdge({*feedNode, 0}, {*tanNode, 0});

    const auto rc = GraphInfershapeUtil::InferShape(*g);
    EXPECT_EQ(ge::GRAPH_SUCCESS, rc);

    TensorDesc tanOut = tanNode->ROLE(NodeSpec).OpDesc().GetOutputDesc(0);
    Shape tanShape = tanOut.GetShape();
    EXPECT_EQ(1, tanShape.GetDim(0));
    EXPECT_EQ(4, tanShape.GetDim(1));
    EXPECT_EQ(3, tanShape.GetDim(2));
    EXPECT_EQ(3, tanShape.GetDim(3));
}

// Data -> FractionalPooling with three outputs.
// Input is a {2, 3, 4, 5} float tensor and pooling_ratio is {1.0, 1.5, 1.6, 1.0}; the test
// expects output 0 to be float {2, 2, 2, 5} and outputs 1 and 2 to be int64 with first dim 3.
TEST_F(UTESTGeGraphInfershape, graph_fractional_pooling_infershape_success)
{
    std::shared_ptr<ComputeGraph> graph = ge::ComputeGraph::Make("ut_test_graph");
    auto&& modifier = graph->ROLE(GraphModifier);

    OpDescPtr dataOp = std::make_shared<OpDesc>("Data", "Data");
    Node* dataNode = modifier.AddNode(dataOp);
    OpDescPtr poolOp = std::make_shared<OpDesc>("FractionalPooling", "FractionalPooling");
    Node* poolNode = modifier.AddNode(poolOp);

    // Operator attributes, set in the same order as the attribute list of the op definition used here.
    poolOp->SetAttr(hiai::op::FractionalPooling::mode, AttrValue::CreateFrom(static_cast<AttrValue::INT>(0)));
    vector<float> poolingRatio = {1.0, 1.5, 1.6, 1.0};
    poolOp->SetAttr(hiai::op::FractionalPooling::pooling_ratio, AttrValue::CreateFrom(poolingRatio));
    poolOp->SetAttr(hiai::op::FractionalPooling::pseudo_random, AttrValue::CreateFrom(false));
    poolOp->SetAttr(hiai::op::FractionalPooling::overlapping, AttrValue::CreateFrom(false));
    poolOp->SetAttr(hiai::op::FractionalPooling::deterministic, AttrValue::CreateFrom(false));
    poolOp->SetAttr(hiai::op::FractionalPooling::seed, AttrValue::CreateFrom(static_cast<AttrValue::INT>(0)));
    poolOp->SetAttr(hiai::op::FractionalPooling::seed2, AttrValue::CreateFrom(static_cast<AttrValue::INT>(1)));

    TensorDesc inDesc(Shape({2, 3, 4, 5}));
    inDesc.SetDataType(DT_FLOAT);
    dataOp->AddInputDesc(inDesc);
    dataOp->AddOutputDesc(inDesc);

    // One described input and three default-constructed outputs to be inferred.
    TensorDesc emptyDesc;
    poolOp->AddInputDesc(inDesc);
    poolOp->AddOutputDesc(emptyDesc);
    poolOp->AddOutputDesc(emptyDesc);
    poolOp->AddOutputDesc(emptyDesc);

    (void)modifier.AddEdge({*dataNode, 0}, {*poolNode, 0});

    const auto status = GraphInfershapeUtil::InferShape(*graph);
    EXPECT_EQ(ge::GRAPH_SUCCESS, status);

    // Output 0: pooled tensor.
    TensorDesc pooled = poolNode->ROLE(NodeSpec).OpDesc().GetOutputDesc(0);
    EXPECT_EQ(pooled.GetDataType(), DT_FLOAT);
    EXPECT_EQ(pooled.GetShape().GetDim(0), 2);
    EXPECT_EQ(pooled.GetShape().GetDim(1), 2);
    EXPECT_EQ(pooled.GetShape().GetDim(2), 2);
    EXPECT_EQ(pooled.GetShape().GetDim(3), 5);

    // Outputs 1 and 2: int64 tensors whose first dim is expected to be 3.
    TensorDesc secondOut = poolNode->ROLE(NodeSpec).OpDesc().GetOutputDesc(1);
    EXPECT_EQ(secondOut.GetDataType(), DT_INT64);
    EXPECT_EQ(secondOut.GetShape().GetDim(0), 3);

    TensorDesc thirdOut = poolNode->ROLE(NodeSpec).OpDesc().GetOutputDesc(2);
    EXPECT_EQ(thirdOut.GetDataType(), DT_INT64);
    EXPECT_EQ(thirdOut.GetShape().GetDim(0), 3);
}

// RandomNormalNoSeed with three weight tensors attached through OpDescUtils::SetWeights:
// an int32 tensor holding {2, 3, 4, 1} and two single-element float tensors (0.0 and 1.0).
// The test expects the inferred output dims to equal the values of the first weight.
TEST_F(UTESTGeGraphInfershape, graph_random_normal_infershape_success)
{
    std::shared_ptr<ComputeGraph> graph = ge::ComputeGraph::Make("ut_test_graph");
    OpDescPtr randomOp = std::make_shared<OpDesc>("RandomNormalNoSeed", "RandomNormalNoSeed");
    Node* randomNode = graph->ROLE(GraphModifier).AddNode(randomOp);

    TensorDesc emptyDesc;
    randomOp->AddOutputDesc(emptyDesc);

    // A Data node without descriptors or edges is also part of the graph; its handle is not needed.
    OpDescPtr dataOp = std::make_shared<OpDesc>("Data", "Data");
    (void)graph->ROLE(GraphModifier).AddNode(dataOp);

    // Weight 0: four int32 values.
    int shapeValues[4] = {2, 3, 4, 1};
    TensorDesc shapeDesc(Shape({4}));
    shapeDesc.SetDataType(DT_INT32);
    TensorPtr shapeTensor =
        std::make_shared<Tensor>(shapeDesc, reinterpret_cast<const uint8_t*>(shapeValues), sizeof(shapeValues));

    // Weight 1: single float 0.0.
    float firstScalar[1] = {0.0f};
    TensorDesc firstScalarDesc(Shape({1}));
    firstScalarDesc.SetDataType(DT_FLOAT);
    TensorPtr firstScalarTensor =
        std::make_shared<Tensor>(firstScalarDesc, reinterpret_cast<const uint8_t*>(firstScalar), sizeof(firstScalar));

    // Weight 2: single float 1.0.
    float secondScalar[1] = {1.0f};
    TensorDesc secondScalarDesc(Shape({1}));
    secondScalarDesc.SetDataType(DT_FLOAT);
    TensorPtr secondScalarTensor = std::make_shared<Tensor>(
        secondScalarDesc, reinterpret_cast<const uint8_t*>(secondScalar), sizeof(secondScalar));

    vector<ge::TensorPtr> weightList;
    weightList.push_back(shapeTensor);
    weightList.push_back(firstScalarTensor);
    weightList.push_back(secondScalarTensor);

    OpDescUtils::SetWeights(*randomNode, weightList);

    const auto status = GraphInfershapeUtil::InferShape(*graph);
    EXPECT_EQ(ge::GRAPH_SUCCESS, status);

    TensorDesc inferred = randomNode->ROLE(NodeSpec).OpDesc().GetOutputDesc(0);
    Shape dims = inferred.GetShape();
    EXPECT_EQ(2, dims.GetDim(0));
    EXPECT_EQ(3, dims.GetDim(1));
    EXPECT_EQ(4, dims.GetDim(2));
    EXPECT_EQ(1, dims.GetDim(3));
}

// Data -> RandomShuffleNoSeed: the inferred output is expected to keep the {4, 1, 3, 2}
// dims of the int32 input descriptor.
TEST_F(UTESTGeGraphInfershape, graph_random_shuffle_infershape_success)
{
    std::shared_ptr<ComputeGraph> graph = ge::ComputeGraph::Make("ut_test_graph");
    auto&& modifier = graph->ROLE(GraphModifier);

    TensorDesc inDesc(Shape({4, 1, 3, 2}));
    inDesc.SetDataType(DT_INT32);
    TensorDesc emptyDesc;

    OpDescPtr dataOp = std::make_shared<OpDesc>("Data", "Data");
    Node* dataNode = modifier.AddNode(dataOp);
    OpDescPtr shuffleOp = std::make_shared<OpDesc>("RandomShuffleNoSeed", "RandomShuffleNoSeed");
    Node* shuffleNode = modifier.AddNode(shuffleOp);

    // Descriptors are attached after both nodes are already part of the graph.
    dataOp->AddInputDesc(inDesc);
    dataOp->AddOutputDesc(inDesc);

    shuffleOp->AddInputDesc(inDesc);
    shuffleOp->AddOutputDesc(emptyDesc);

    (void)modifier.AddEdge({*dataNode, 0}, {*shuffleNode, 0});

    const auto status = GraphInfershapeUtil::InferShape(*graph);
    EXPECT_EQ(ge::GRAPH_SUCCESS, status);

    TensorDesc inferred = shuffleNode->ROLE(NodeSpec).OpDesc().GetOutputDesc(0);
    Shape dims = inferred.GetShape();
    EXPECT_EQ(4, dims.GetDim(0));
    EXPECT_EQ(1, dims.GetDim(1));
    EXPECT_EQ(3, dims.GetDim(2));
    EXPECT_EQ(2, dims.GetDim(3));
}

// Three Data nodes -> Select, all with shape {2, 1, 2, 4} (input 0 is bool, inputs 1 and 2 float).
// Inference must succeed with output dims {2, 1, 2, 4}; afterwards OpIRFacade::GetInputs is
// exercised on the first Data node and is expected to return GRAPH_SUCCESS.
TEST_F(UTESTGeGraphInfershape, graph_select_condition4Dim_4D_infershape_success)
{
    std::shared_ptr<ComputeGraph> graph = ge::ComputeGraph::Make("ut_test_graph");
    auto&& modifier = graph->ROLE(GraphModifier);

    OpDescPtr condOp = std::make_shared<OpDesc>("Data", "Data");
    Node* condNode = modifier.AddNode(condOp);
    OpDescPtr firstOp = std::make_shared<OpDesc>("Data", "Data");
    Node* firstNode = modifier.AddNode(firstOp);
    OpDescPtr secondOp = std::make_shared<OpDesc>("Data", "Data");
    Node* secondNode = modifier.AddNode(secondOp);
    OpDescPtr selectOp = std::make_shared<OpDesc>("Select", "Select");
    Node* selectNode = modifier.AddNode(selectOp);

    TensorDesc condDesc(Shape({2, 1, 2, 4}));
    condDesc.SetDataType(DT_BOOL);
    TensorDesc firstDesc(Shape({2, 1, 2, 4}));
    firstDesc.SetDataType(DT_FLOAT);
    TensorDesc secondDesc(Shape({2, 1, 2, 4}));
    secondDesc.SetDataType(DT_FLOAT);
    TensorDesc emptyDesc;

    // Each Data op uses the same descriptor for its input and its output.
    condOp->AddInputDesc(condDesc);
    condOp->AddOutputDesc(condDesc);
    firstOp->AddInputDesc(firstDesc);
    firstOp->AddOutputDesc(firstDesc);
    secondOp->AddInputDesc(secondDesc);
    secondOp->AddOutputDesc(secondDesc);

    selectOp->AddInputDesc(condDesc);
    selectOp->AddInputDesc(firstDesc);
    selectOp->AddInputDesc(secondDesc);
    selectOp->AddOutputDesc(emptyDesc);

    (void)modifier.AddEdge({*condNode, 0}, {*selectNode, 0});
    (void)modifier.AddEdge({*firstNode, 0}, {*selectNode, 1});
    (void)modifier.AddEdge({*secondNode, 0}, {*selectNode, 2});

    const auto status = GraphInfershapeUtil::InferShape(*graph);
    EXPECT_EQ(ge::GRAPH_SUCCESS, status);

    TensorDesc inferred = selectNode->ROLE(NodeSpec).OpDesc().GetOutputDesc(0);
    Shape dims = inferred.GetShape();
    EXPECT_EQ(2, dims.GetDim(0));
    EXPECT_EQ(1, dims.GetDim(1));
    EXPECT_EQ(2, dims.GetDim(2));
    EXPECT_EQ(4, dims.GetDim(3));

    // Extra coverage of the facade API on a Data node.
    OpIRFacade facade(*condNode);
    std::vector<ge::TensorDesc> facadeInputs;
    EXPECT_EQ(ge::GRAPH_SUCCESS, facade.GetInputs(facadeInputs));
}

// Select whose third input ({1, 1, 2, 4}) differs from the other two ({2, 1, 2, 4}):
// shape inference is expected to return GRAPH_FAILED. The test then calls
// OpIRFacade::GetSubGraph on the first Data node with an empty attribute name and expects nullptr.
TEST_F(UTESTGeGraphInfershape, graph_select_condition4Dim_4D_infershape_failed)
{
    std::shared_ptr<ComputeGraph> graph = ge::ComputeGraph::Make("ut_test_graph");
    auto&& modifier = graph->ROLE(GraphModifier);

    OpDescPtr condOp = std::make_shared<OpDesc>("Data", "Data");
    Node* condNode = modifier.AddNode(condOp);
    OpDescPtr firstOp = std::make_shared<OpDesc>("Data", "Data");
    Node* firstNode = modifier.AddNode(firstOp);
    OpDescPtr secondOp = std::make_shared<OpDesc>("Data", "Data");
    Node* secondNode = modifier.AddNode(secondOp);
    OpDescPtr selectOp = std::make_shared<OpDesc>("Select", "Select");
    Node* selectNode = modifier.AddNode(selectOp);

    TensorDesc condDesc(Shape({2, 1, 2, 4}));
    condDesc.SetDataType(DT_BOOL);
    condOp->AddInputDesc(condDesc);
    condOp->AddOutputDesc(condDesc);

    TensorDesc firstDesc(Shape({2, 1, 2, 4}));
    firstDesc.SetDataType(DT_FLOAT);
    firstOp->AddInputDesc(firstDesc);
    firstOp->AddOutputDesc(firstDesc);

    // Deliberately mismatched first dim (1 instead of 2).
    TensorDesc secondDesc(Shape({1, 1, 2, 4}));
    secondDesc.SetDataType(DT_FLOAT);
    secondOp->AddInputDesc(secondDesc);
    secondOp->AddOutputDesc(secondDesc);

    TensorDesc emptyDesc;
    selectOp->AddInputDesc(condDesc);
    selectOp->AddInputDesc(firstDesc);
    selectOp->AddInputDesc(secondDesc);
    selectOp->AddOutputDesc(emptyDesc);

    (void)modifier.AddEdge({*condNode, 0}, {*selectNode, 0});
    (void)modifier.AddEdge({*firstNode, 0}, {*selectNode, 1});
    (void)modifier.AddEdge({*secondNode, 0}, {*selectNode, 2});

    const auto status = GraphInfershapeUtil::InferShape(*graph);
    EXPECT_EQ(ge::GRAPH_FAILED, status);

    OpIRFacade facade(*condNode);
    std::string emptyAttrName;
    ComputeGraphPtr subGraphOut;
    EXPECT_EQ(nullptr, facade.GetSubGraph(emptyAttrName, subGraphOut));
}

// Select fed by three 2-D Data nodes of shape {2, 1} (bool, float, float).
// Inference must succeed and the output is expected to be {2, 1}.
TEST_F(UTESTGeGraphInfershape, graph_select_condition_4Dim_2D_infershape_success)
{
    std::shared_ptr<ComputeGraph> graph = ge::ComputeGraph::Make("ut_test_graph");

    OpDescPtr condOp = std::make_shared<OpDesc>("Data", "Data");
    Node* condNode = graph->ROLE(GraphModifier).AddNode(condOp);
    OpDescPtr firstOp = std::make_shared<OpDesc>("Data", "Data");
    Node* firstNode = graph->ROLE(GraphModifier).AddNode(firstOp);
    OpDescPtr secondOp = std::make_shared<OpDesc>("Data", "Data");
    Node* secondNode = graph->ROLE(GraphModifier).AddNode(secondOp);
    OpDescPtr selectOp = std::make_shared<OpDesc>("Select", "Select");
    Node* selectNode = graph->ROLE(GraphModifier).AddNode(selectOp);

    // All three operands share the {2, 1} shape; only the data type differs.
    TensorDesc condDesc(Shape({2, 1}));
    condDesc.SetDataType(DT_BOOL);
    TensorDesc firstDesc(Shape({2, 1}));
    firstDesc.SetDataType(DT_FLOAT);
    TensorDesc secondDesc(Shape({2, 1}));
    secondDesc.SetDataType(DT_FLOAT);

    condOp->AddInputDesc(condDesc);
    condOp->AddOutputDesc(condDesc);
    firstOp->AddInputDesc(firstDesc);
    firstOp->AddOutputDesc(firstDesc);
    secondOp->AddInputDesc(secondDesc);
    secondOp->AddOutputDesc(secondDesc);

    TensorDesc emptyDesc;
    selectOp->AddInputDesc(condDesc);
    selectOp->AddInputDesc(firstDesc);
    selectOp->AddInputDesc(secondDesc);
    selectOp->AddOutputDesc(emptyDesc);

    (void)graph->ROLE(GraphModifier).AddEdge({*condNode, 0}, {*selectNode, 0});
    (void)graph->ROLE(GraphModifier).AddEdge({*firstNode, 0}, {*selectNode, 1});
    (void)graph->ROLE(GraphModifier).AddEdge({*secondNode, 0}, {*selectNode, 2});

    const auto status = GraphInfershapeUtil::InferShape(*graph);
    EXPECT_EQ(ge::GRAPH_SUCCESS, status);

    TensorDesc inferred = selectNode->ROLE(NodeSpec).OpDesc().GetOutputDesc(0);
    Shape dims = inferred.GetShape();
    EXPECT_EQ(2, dims.GetDim(0));
    EXPECT_EQ(1, dims.GetDim(1));
}

// Select with a 1-D bool input of shape {4} and two 4-D float inputs of shape {2, 1, 2, 4}.
// Inference must succeed and the output is expected to be {2, 1, 2, 4}.
TEST_F(UTESTGeGraphInfershape, graph_select_condition_1Dim_infershape_success)
{
    std::shared_ptr<ComputeGraph> graph = ge::ComputeGraph::Make("ut_test_graph");
    auto&& modifier = graph->ROLE(GraphModifier);

    OpDescPtr condOp = std::make_shared<OpDesc>("Data", "Data");
    Node* condNode = modifier.AddNode(condOp);
    OpDescPtr firstOp = std::make_shared<OpDesc>("Data", "Data");
    Node* firstNode = modifier.AddNode(firstOp);
    OpDescPtr secondOp = std::make_shared<OpDesc>("Data", "Data");
    Node* secondNode = modifier.AddNode(secondOp);
    OpDescPtr selectOp = std::make_shared<OpDesc>("Select", "Select");
    Node* selectNode = modifier.AddNode(selectOp);

    // 1-D bool operand.
    TensorDesc condDesc(Shape({4}));
    condDesc.SetDataType(DT_BOOL);
    condOp->AddInputDesc(condDesc);
    condOp->AddOutputDesc(condDesc);

    // Two 4-D float operands with identical shapes.
    TensorDesc firstDesc(Shape({2, 1, 2, 4}));
    firstDesc.SetDataType(DT_FLOAT);
    firstOp->AddInputDesc(firstDesc);
    firstOp->AddOutputDesc(firstDesc);

    TensorDesc secondDesc(Shape({2, 1, 2, 4}));
    secondDesc.SetDataType(DT_FLOAT);
    secondOp->AddInputDesc(secondDesc);
    secondOp->AddOutputDesc(secondDesc);

    TensorDesc emptyDesc;
    selectOp->AddInputDesc(condDesc);
    selectOp->AddInputDesc(firstDesc);
    selectOp->AddInputDesc(secondDesc);
    selectOp->AddOutputDesc(emptyDesc);

    (void)modifier.AddEdge({*condNode, 0}, {*selectNode, 0});
    (void)modifier.AddEdge({*firstNode, 0}, {*selectNode, 1});
    (void)modifier.AddEdge({*secondNode, 0}, {*selectNode, 2});

    const auto status = GraphInfershapeUtil::InferShape(*graph);
    EXPECT_EQ(ge::GRAPH_SUCCESS, status);

    TensorDesc inferred = selectNode->ROLE(NodeSpec).OpDesc().GetOutputDesc(0);
    Shape dims = inferred.GetShape();
    EXPECT_EQ(2, dims.GetDim(0));
    EXPECT_EQ(1, dims.GetDim(1));
    EXPECT_EQ(2, dims.GetDim(2));
    EXPECT_EQ(4, dims.GetDim(3));
}

// Select with a 1-D bool input of shape {4} and two 2-D float inputs of shape {2, 4}.
// Inference must succeed and the output is expected to be {2, 4}.
TEST_F(UTESTGeGraphInfershape, graph_select_condition_1Dim_2D_infershape_success)
{
    std::shared_ptr<ComputeGraph> graph = ge::ComputeGraph::Make("ut_test_graph");

    OpDescPtr condOp = std::make_shared<OpDesc>("Data", "Data");
    Node* condNode = graph->ROLE(GraphModifier).AddNode(condOp);
    OpDescPtr firstOp = std::make_shared<OpDesc>("Data", "Data");
    Node* firstNode = graph->ROLE(GraphModifier).AddNode(firstOp);
    OpDescPtr secondOp = std::make_shared<OpDesc>("Data", "Data");
    Node* secondNode = graph->ROLE(GraphModifier).AddNode(secondOp);
    OpDescPtr selectOp = std::make_shared<OpDesc>("Select", "Select");
    Node* selectNode = graph->ROLE(GraphModifier).AddNode(selectOp);

    TensorDesc condDesc(Shape({4}));
    condDesc.SetDataType(DT_BOOL);
    TensorDesc firstDesc(Shape({2, 4}));
    firstDesc.SetDataType(DT_FLOAT);
    TensorDesc secondDesc(Shape({2, 4}));
    secondDesc.SetDataType(DT_FLOAT);
    TensorDesc emptyDesc;

    // Producers: each Data op mirrors its descriptor on input and output.
    condOp->AddInputDesc(condDesc);
    condOp->AddOutputDesc(condDesc);
    firstOp->AddInputDesc(firstDesc);
    firstOp->AddOutputDesc(firstDesc);
    secondOp->AddInputDesc(secondDesc);
    secondOp->AddOutputDesc(secondDesc);

    // Consumer: three described inputs, one output left for inference.
    selectOp->AddInputDesc(condDesc);
    selectOp->AddInputDesc(firstDesc);
    selectOp->AddInputDesc(secondDesc);
    selectOp->AddOutputDesc(emptyDesc);

    (void)graph->ROLE(GraphModifier).AddEdge({*condNode, 0}, {*selectNode, 0});
    (void)graph->ROLE(GraphModifier).AddEdge({*firstNode, 0}, {*selectNode, 1});
    (void)graph->ROLE(GraphModifier).AddEdge({*secondNode, 0}, {*selectNode, 2});

    const auto status = GraphInfershapeUtil::InferShape(*graph);
    EXPECT_EQ(ge::GRAPH_SUCCESS, status);

    TensorDesc inferred = selectNode->ROLE(NodeSpec).OpDesc().GetOutputDesc(0);
    Shape dims = inferred.GetShape();
    EXPECT_EQ(2, dims.GetDim(0));
    EXPECT_EQ(4, dims.GetDim(1));
}

// Select with a 1-D bool input of shape {5} and two 2-D float inputs of shape {2, 4}.
// Shape inference is expected to return GRAPH_FAILED for this combination.
TEST_F(UTESTGeGraphInfershape, graph_select_condition_1Dim_2D_infershape_failed)
{
    std::shared_ptr<ComputeGraph> graph = ge::ComputeGraph::Make("ut_test_graph");
    auto&& modifier = graph->ROLE(GraphModifier);

    OpDescPtr condOp = std::make_shared<OpDesc>("Data", "Data");
    Node* condNode = modifier.AddNode(condOp);
    OpDescPtr firstOp = std::make_shared<OpDesc>("Data", "Data");
    Node* firstNode = modifier.AddNode(firstOp);
    OpDescPtr secondOp = std::make_shared<OpDesc>("Data", "Data");
    Node* secondNode = modifier.AddNode(secondOp);
    OpDescPtr selectOp = std::make_shared<OpDesc>("Select", "Select");
    Node* selectNode = modifier.AddNode(selectOp);

    // Bool operand with 5 elements; the float operands below are {2, 4}.
    TensorDesc condDesc(Shape({5}));
    condDesc.SetDataType(DT_BOOL);
    condOp->AddInputDesc(condDesc);
    condOp->AddOutputDesc(condDesc);

    TensorDesc firstDesc(Shape({2, 4}));
    firstDesc.SetDataType(DT_FLOAT);
    firstOp->AddInputDesc(firstDesc);
    firstOp->AddOutputDesc(firstDesc);

    TensorDesc secondDesc(Shape({2, 4}));
    secondDesc.SetDataType(DT_FLOAT);
    secondOp->AddInputDesc(secondDesc);
    secondOp->AddOutputDesc(secondDesc);

    TensorDesc emptyDesc;
    selectOp->AddInputDesc(condDesc);
    selectOp->AddInputDesc(firstDesc);
    selectOp->AddInputDesc(secondDesc);
    selectOp->AddOutputDesc(emptyDesc);

    (void)modifier.AddEdge({*condNode, 0}, {*selectNode, 0});
    (void)modifier.AddEdge({*firstNode, 0}, {*selectNode, 1});
    (void)modifier.AddEdge({*secondNode, 0}, {*selectNode, 2});

    const auto status = GraphInfershapeUtil::InferShape(*graph);
    EXPECT_EQ(ge::GRAPH_FAILED, status);
}

// Data({2, 1, 2, 4}) -> TopK with a single int32 weight holding k = 2.
// Inference must succeed and output 0 is expected to be {2, 1, 2, 2}.
TEST_F(UTESTGeGraphInfershape, graph_topk_4D_infershape_success)
{
    std::shared_ptr<ComputeGraph> graph = ge::ComputeGraph::Make("ut_test_graph");
    auto&& modifier = graph->ROLE(GraphModifier);

    TensorDesc inDesc(Shape({2, 1, 2, 4}));
    inDesc.SetDataType(DT_FLOAT);

    OpDescPtr dataOp = std::make_shared<OpDesc>("Data", "Data");
    dataOp->AddInputDesc(inDesc);
    dataOp->AddOutputDesc(inDesc);
    Node* dataNode = modifier.AddNode(dataOp);

    // TopK has one described input and two outputs to be inferred.
    OpDescPtr topKOp = std::make_shared<OpDesc>("TopK", "TopK");
    TensorDesc firstOutDesc;
    TensorDesc secondOutDesc;
    topKOp->AddInputDesc(inDesc);
    topKOp->AddOutputDesc(firstOutDesc);
    topKOp->AddOutputDesc(secondOutDesc);
    Node* topKNode = modifier.AddNode(topKOp);
    topKOp->SetAttr("sorted", AttrValue::CreateFrom(false));

    // k is supplied as a one-element int32 weight tensor.
    int32_t k = 2;
    TensorDesc kDesc(Shape({1}));
    kDesc.SetDataType(DT_INT32);
    TensorPtr kTensor = std::make_shared<Tensor>(kDesc, reinterpret_cast<const uint8_t*>(&k), sizeof(k));
    vector<ge::TensorPtr> weightList;
    weightList.push_back(kTensor);
    OpDescUtils::SetWeights(*topKNode, weightList);

    (void)modifier.AddEdge({*dataNode, 0}, {*topKNode, 0});

    const auto status = GraphInfershapeUtil::InferShape(*graph);
    EXPECT_EQ(ge::GRAPH_SUCCESS, status);

    TensorDesc inferred = topKNode->ROLE(NodeSpec).OpDesc().GetOutputDesc(0);
    Shape dims = inferred.GetShape();
    EXPECT_EQ(2, dims.GetDim(0));
    EXPECT_EQ(1, dims.GetDim(1));
    EXPECT_EQ(2, dims.GetDim(2));
    EXPECT_EQ(2, dims.GetDim(3));
}

// Data({2, 4}) -> TopK with a single int32 weight holding k = 2.
// Inference must succeed and output 0 is expected to be {2, 2}.
TEST_F(UTESTGeGraphInfershape, graph_topk_2D_infershape_success)
{
    std::shared_ptr<ComputeGraph> graph = ge::ComputeGraph::Make("ut_test_graph");

    // The TopK op description is created (and its attribute set) before the Data op.
    OpDescPtr topKOp = std::make_shared<OpDesc>("TopK", "TopK");
    topKOp->SetAttr("sorted", AttrValue::CreateFrom(false));

    TensorDesc inDesc(Shape({2, 4}));
    inDesc.SetDataType(DT_FLOAT);

    OpDescPtr dataOp = std::make_shared<OpDesc>("Data", "Data");
    dataOp->AddInputDesc(inDesc);
    dataOp->AddOutputDesc(inDesc);
    Node* dataNode = graph->ROLE(GraphModifier).AddNode(dataOp);

    TensorDesc firstOutDesc;
    TensorDesc secondOutDesc;
    topKOp->AddInputDesc(inDesc);
    topKOp->AddOutputDesc(firstOutDesc);
    topKOp->AddOutputDesc(secondOutDesc);
    Node* topKNode = graph->ROLE(GraphModifier).AddNode(topKOp);

    // k is supplied as a one-element int32 weight tensor.
    int32_t k = 2;
    TensorDesc kDesc(Shape({1}));
    kDesc.SetDataType(DT_INT32);
    TensorPtr kTensor = std::make_shared<Tensor>(kDesc, reinterpret_cast<const uint8_t*>(&k), sizeof(k));
    vector<ge::TensorPtr> weightList;
    weightList.push_back(kTensor);
    OpDescUtils::SetWeights(*topKNode, weightList);

    (void)graph->ROLE(GraphModifier).AddEdge({*dataNode, 0}, {*topKNode, 0});

    const auto status = GraphInfershapeUtil::InferShape(*graph);
    EXPECT_EQ(ge::GRAPH_SUCCESS, status);

    TensorDesc inferred = topKNode->ROLE(NodeSpec).OpDesc().GetOutputDesc(0);
    Shape dims = inferred.GetShape();
    EXPECT_EQ(2, dims.GetDim(0));
    EXPECT_EQ(2, dims.GetDim(1));
}

// Data({2}) -> TopK with a single int32 weight holding k = 2.
// Inference must succeed and output 0 is expected to have first dim 2.
TEST_F(UTESTGeGraphInfershape, graph_topk_1D_infershape_success)
{
    std::shared_ptr<ComputeGraph> graph = ge::ComputeGraph::Make("ut_test_graph");
    auto&& modifier = graph->ROLE(GraphModifier);

    // The TopK op description is created (and its attribute set) before the Data op.
    OpDescPtr topKOp = std::make_shared<OpDesc>("TopK", "TopK");
    topKOp->SetAttr("sorted", AttrValue::CreateFrom(false));

    TensorDesc inDesc(Shape({2}));
    inDesc.SetDataType(DT_FLOAT);

    OpDescPtr dataOp = std::make_shared<OpDesc>("Data", "Data");
    dataOp->AddInputDesc(inDesc);
    dataOp->AddOutputDesc(inDesc);
    Node* dataNode = modifier.AddNode(dataOp);

    TensorDesc firstOutDesc;
    TensorDesc secondOutDesc;
    topKOp->AddInputDesc(inDesc);
    topKOp->AddOutputDesc(firstOutDesc);
    topKOp->AddOutputDesc(secondOutDesc);
    Node* topKNode = modifier.AddNode(topKOp);

    // k is supplied as a one-element int32 weight tensor.
    int32_t k = 2;
    TensorDesc kDesc(Shape({1}));
    kDesc.SetDataType(DT_INT32);
    TensorPtr kTensor = std::make_shared<Tensor>(kDesc, reinterpret_cast<const uint8_t*>(&k), sizeof(k));
    vector<ge::TensorPtr> weightList;
    weightList.push_back(kTensor);
    OpDescUtils::SetWeights(*topKNode, weightList);

    (void)modifier.AddEdge({*dataNode, 0}, {*topKNode, 0});

    const auto status = GraphInfershapeUtil::InferShape(*graph);
    EXPECT_EQ(ge::GRAPH_SUCCESS, status);

    TensorDesc inferred = topKNode->ROLE(NodeSpec).OpDesc().GetOutputDesc(0);
    Shape dims = inferred.GetShape();
    EXPECT_EQ(2, dims.GetDim(0));
}

namespace {
// Plain bundle of the SSDPriorBox settings used by the priorbox test(s).
// Field order matters: the tests fill this struct with positional aggregate initialization.
struct PriorboxAttr {
    std::vector<float> minsizes;     // written to "min_size" (its length to "min_size_num")
    std::vector<float> maxsizes;     // written to "max_size" (its length to "max_size_num")
    std::vector<float> aspectratios; // written to "aspect_ratio" (its length to "aspect_ratio_num")
    bool flip;                       // written to "flip"
    int numpriors;                   // not read by SetPriorboxParameter
    bool clip;                       // written to "clip"
    std::vector<float> variance;     // written to "variance" (its length to "variance_num")
    int imgw;                        // written to "img_w"
    int imgh;                        // written to "img_h"
    int layerwidth;                  // last dim of the feature-map descriptor
    int layerheight;                 // third dim of the feature-map descriptor
    float stepw;                     // written to "step_w"
    float steph;                     // written to "step_h"
    float offset;                    // written to "offset"
};

// Copies the settings in priorboxPara onto priorboxOpDescPtr as attributes and registers
// tensor descriptors on both op descriptions.
//
// priorboxOpDescPtr  op description that receives the attributes, one input descriptor of shape
//                    {1, 16, layerheight, layerwidth}, optionally a second {1, 16, 600, 600}
//                    input, and one default-constructed output descriptor.
// dataOpdescPtr      op description that receives the {1, 16, layerheight, layerwidth} float
//                    descriptor as both its input and its output.
// priorboxPara       source of all attribute values.
// addImg             when true, the second (image) input descriptor is added to the priorbox op.
void SetPriorboxParameter(
    OpDescPtr& priorboxOpDescPtr, OpDescPtr& dataOpdescPtr, PriorboxAttr& priorboxPara, bool addImg = false)
{
    PriorboxAttr& cfg = priorboxPara;

    // List attributes, each followed by its element count.
    ge::AttrUtils::SetListFloat(priorboxOpDescPtr, "min_size", cfg.minsizes);
    ge::AttrUtils::SetInt(priorboxOpDescPtr, "min_size_num", cfg.minsizes.size());
    ge::AttrUtils::SetListFloat(priorboxOpDescPtr, "max_size", cfg.maxsizes);
    ge::AttrUtils::SetInt(priorboxOpDescPtr, "max_size_num", cfg.maxsizes.size());
    ge::AttrUtils::SetListFloat(priorboxOpDescPtr, "aspect_ratio", cfg.aspectratios);
    ge::AttrUtils::SetInt(priorboxOpDescPtr, "aspect_ratio_num", cfg.aspectratios.size());
    ge::AttrUtils::SetListFloat(priorboxOpDescPtr, "variance", cfg.variance);
    ge::AttrUtils::SetInt(priorboxOpDescPtr, "variance_num", cfg.variance.size());

    // Scalar attributes.
    ge::AttrUtils::SetInt(priorboxOpDescPtr, "img_h", cfg.imgh);
    ge::AttrUtils::SetInt(priorboxOpDescPtr, "img_w", cfg.imgw);
    ge::AttrUtils::SetBool(priorboxOpDescPtr, "flip", cfg.flip);
    ge::AttrUtils::SetBool(priorboxOpDescPtr, "clip", cfg.clip);
    ge::AttrUtils::SetFloat(priorboxOpDescPtr, "step_h", cfg.steph);
    ge::AttrUtils::SetFloat(priorboxOpDescPtr, "step_w", cfg.stepw);
    ge::AttrUtils::SetFloat(priorboxOpDescPtr, "offset", cfg.offset);

    // Feature-map descriptor shared by the Data op and the first priorbox input.
    TensorDesc featureDesc(Shape({1, 16, priorboxPara.layerheight, priorboxPara.layerwidth}));
    featureDesc.SetDataType(DT_FLOAT);
    dataOpdescPtr->AddInputDesc(featureDesc);
    dataOpdescPtr->AddOutputDesc(featureDesc);

    priorboxOpDescPtr->AddInputDesc(featureDesc);
    if (addImg) {
        // Optional second input with a fixed {1, 16, 600, 600} float descriptor.
        TensorDesc imageDesc(Shape({1, 16, 600, 600}));
        imageDesc.SetDataType(DT_FLOAT);
        priorboxOpDescPtr->AddInputDesc(imageDesc);
    }

    // The output descriptor is left default-constructed.
    TensorDesc emptyDesc;
    priorboxOpDescPtr->AddOutputDesc(emptyDesc);
}
} // namespace

// Data -> SSDPriorBox configured through SetPriorboxParameter (without the image input).
// Inference must succeed and the output is expected to be {1, 2, 90000, 1}.
TEST_F(UTESTGeGraphInfershape, graph_priorbox_infershape_success)
{
    std::shared_ptr<ComputeGraph> graph = ge::ComputeGraph::Make("ut_test_graph");
    auto&& modifier = graph->ROLE(GraphModifier);

    OpDescPtr dataOp = std::make_shared<OpDesc>("Data", "Data");
    Node* dataNode = modifier.AddNode(dataOp);

    OpDescPtr priorboxOp = std::make_shared<OpDesc>("SSDPriorBox", "SSDPriorBox");
    Node* priorboxNode = modifier.AddNode(priorboxOp);

    // Positional initialization in PriorboxAttr field order:
    // minsizes, maxsizes, aspectratios, flip, numpriors, clip, variance,
    // imgw, imgh, layerwidth, layerheight, stepw, steph, offset.
    PriorboxAttr settings = {{30}, {60}, {2}, true, 0, false, {0.1}, 600, 600, 75, 75, 0, 0, 0.5};

    SetPriorboxParameter(priorboxOp, dataOp, settings);

    (void)modifier.AddEdge({*dataNode, 0}, {*priorboxNode, 0});

    const auto status = GraphInfershapeUtil::InferShape(*graph);
    EXPECT_EQ(ge::GRAPH_SUCCESS, status);

    TensorDesc inferred = priorboxNode->ROLE(NodeSpec).OpDesc().GetOutputDesc(0);
    Shape dims = inferred.GetShape();
    EXPECT_EQ(1, dims.GetDim(0));
    EXPECT_EQ(2, dims.GetDim(1));
    EXPECT_EQ(90000, dims.GetDim(2));
    EXPECT_EQ(1, dims.GetDim(3));
}